import argparse


def get_train_args():
    parser = argparse.ArgumentParser(description="训练CIFAR-10识别模型")
    parser.add_argument("--batch_size", type=int, default=64, help="训练批大小")
    parser.add_argument("--epochs", type=int, default=10, help="训练轮数")
    parser.add_argument(
        "--download_dataset", action="store_true", help="是否只下载数据集"
    )
    parser.add_argument(
        "--show_datashape", action="store_true", help="是否只查看数据的形状"
    )
    parser.add_argument(
        "--quiet_save_model",
        action="store_true",
        help="是否静默保存模型",
    )
    parser.add_argument(
        "--save_model_name",
        type=str,
        default="model.pth",
        help="保存模型的文件名",
    )
    parser.add_argument(
        "--continue_train_model",
        action="store_true",
        help="是否继续训练模型",
    )
    parser.add_argument(
        "--continue_train_model_name",
        type=str,
        default="model.pth",
        help="加载对应文件名的模型继续训练",
    )
    return parser.parse_args()


def get_run_args():
    parser = argparse.ArgumentParser(description="运行CIFAR-10识别模型")
    parser.add_argument(
        "--model_name",
        type=str,
        default="model.pth",
        help="模型的文件名",
    )
    parser.add_argument(
        "--share", action="store_true", help="WebUI是否公开(绑定0.0.0.0)"
    )
    return parser.parse_args()
